v0.1.2-alpha.1 release, featuring MTP (#16)
Pull request overview
This PR prepares the v0.1.2-alpha.1 release by adding Multi-Token Prediction (MTP) support to the DeepSeek v3.2 “show hands” pipeline, alongside weight-conversion utilities and updated generation tooling/docs.
Changes:
- Add MTP E2E execution path and speculative decoding logic (new MTP layer wrapper + generator support).
- Introduce/extend weight conversion utilities for MMA-friendly layouts and additional parameter initializers.
- Update CLI generation script + README to document MTP usage and benchmarks; switch library loading to torch.ops.load_library.
Reviewed changes
Copilot reviewed 13 out of 14 changed files in this pull request and generated 9 comments.
Summary per file:
| File | Description |
|---|---|
| requirements-ci.txt | Removed CI requirements file (currently referenced by GitHub Actions). |
| python/utils.py | Cast cosine-similarity inputs to float for stable numeric behavior. |
| python/models/utils.py | Add swizzle enum + swizzle-map related comments. |
| python/models/preprocess/weight_utils.py | Add/reshape multiple weight converters (MMA swizzles, unproj/down allreduce, qkv conversions). |
| python/models/deepseek_v3_2/params.py | Add new param containers + large DSA/MTP initializer; expand TempVars for MTP. |
| python/models/deepseek_v3_2/dsa_show_hands.py | Add on-demand conversion + MTP integration into weight loading and generation flow. |
| python/models/deepseek_v3_2/dsa_mtp_e2e_show_hands.py | New MTP E2E wrapper around TileRT ops with helper accessors. |
| python/models/deepseek_v3_2/__init__.py | Add package init. |
| python/models/base.py | Extend kernel type options; adjust to_tilert_weights return type. |
| python/generate.py | Add MTP CLI flags + basic perf harness and cleanup. |
| python/__init__.py | Switch shared library loading to torch.ops.load_library. |
| README.md | Add MTP documentation and release notes / usage examples. |
Comments suppressed due to low confidence (1)
requirements-ci.txt:1
- Deleting requirements-ci.txt will break the GitHub Actions lint workflow, which currently runs `pip install -r requirements-ci.txt` (see .github/workflows/lint.yml). Either restore this file or update the workflow to install an existing dependency set (e.g., requirements-dev.txt or `pip install .[dev]`).
```python
def tilert_to_tilert_112sm_mma(
    mat_in: torch.Tensor,
    mat_scale_in: torch.Tensor,
) -> torch.Tensor:
```
tilert_to_tilert_112sm_mma is annotated as returning torch.Tensor, but it actually returns a tuple (weights, scales). This will fail mypy and is misleading for callers (and is already being unpacked as a tuple elsewhere). Change the return annotation to tuple[torch.Tensor, torch.Tensor].
```diff
-) -> torch.Tensor:
+) -> tuple[torch.Tensor, torch.Tensor]:
```
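A torch-free sketch of why the annotation matters: a function that returns two values must advertise a tuple, or type checkers flag every unpacking call site. The body below is a stand-in for illustration, not the PR's real conversion logic:

```python
# Stand-in for a converter that returns (weights, scales); the tuple
# return annotation lets mypy verify unpacking at every call site.
def convert(mat: tuple, scales: tuple) -> tuple[list, list]:
    return list(mat), list(scales)

# Callers unpack two values, matching the declared return type.
weights, out_scales = convert((1, 2), (3, 4))
```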
```python
all_times_np = np.array(all_times)
for token_num in range(100, 300, 100):
    mean_time = np.mean(all_times_np[..., 5:token_num])
    speed = 1 / mean_time
    out_str = (
        f"**Perf@{token_num}: {speed:.3f} tokens/s & {(mean_time * 1000):.3f} ms**"
    )
    print(out_str)
```
all_times_np = np.array(all_times) will produce a 1D object array if per-iteration time_list lengths differ (common if EOS is hit early). In that case, all_times_np[..., 5:token_num] slices the outer array, not each run’s token timings, and np.mean can return incorrect results or NaN. Compute per-run metrics first (e.g., iterate all_times and average time_list[5:token_num] when available) and then aggregate those scalars across runs.
```diff
-all_times_np = np.array(all_times)
-for token_num in range(100, 300, 100):
-    mean_time = np.mean(all_times_np[..., 5:token_num])
-    speed = 1 / mean_time
-    out_str = (
-        f"**Perf@{token_num}: {speed:.3f} tokens/s & {(mean_time * 1000):.3f} ms**"
-    )
-    print(out_str)
+for token_num in range(100, 300, 100):
+    per_run_means = []
+    for time_list in all_times:
+        # Require enough tokens to compute stats from token 5 up to token_num
+        if len(time_list) > 5 and len(time_list) >= token_num:
+            slice_times = time_list[5:token_num]
+            if slice_times:
+                per_run_means.append(float(np.mean(slice_times)))
+    if per_run_means:
+        mean_time = float(np.mean(per_run_means))
+        speed = 1 / mean_time
+        out_str = (
+            f"**Perf@{token_num}: {speed:.3f} tokens/s & {(mean_time * 1000):.3f} ms**"
+        )
+        print(out_str)
```
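The pitfall is easy to reproduce: building an array from ragged per-run lists yields a 1D object array, so the `[..., 5:token_num]` slice selects runs rather than token timings. A minimal demonstration with illustrative numbers (note that recent NumPy versions require `dtype=object` explicitly and otherwise raise on ragged input):

```python
import numpy as np

# Two runs with different lengths, e.g. the second hit EOS early.
all_times = [[0.01] * 150, [0.01] * 40]
all_times_np = np.array(all_times, dtype=object)  # 1D object array, shape (2,)

# "Last axis" slicing actually slices the outer runs axis here:
sliced = all_times_np[..., 5:100]
print(sliced.shape)  # empty: there are only 2 runs, so 5:100 selects nothing
```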
```python
print("Prompt:", prompt)
print("Completion:")
completion = generator.generate(prompt)
```
Same issue as above: ShowHandsGenerator.generate() returns a tuple now, so this MTP example should unpack the returned (text, time_list, accepted_counts) instead of treating it as a single string.
```diff
-completion = generator.generate(prompt)
+completion, time_list, accepted_counts = generator.generate(prompt)
+print(completion)
```
> This example highlights how MTP enables TileRT to efficiently generate longer outputs by accepting multiple tokens per decoding step, while preserving the same Python API interface.
This sentence claims the “same Python API interface” is preserved, but ShowHandsGenerator.generate()’s return type has changed from str to a tuple in this PR. Either adjust the wording or keep the return type stable and expose performance metrics via an optional flag/side channel.
```diff
-This example highlights how MTP enables TileRT to efficiently generate longer outputs by accepting multiple tokens per decoding step, while preserving the same Python API interface.
+This example highlights how MTP enables TileRT to efficiently generate longer outputs by accepting multiple tokens per decoding step.
```
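The reviewer's second option would keep the README claim true: leave `generate()` returning `str` by default and expose metrics behind an opt-in flag. A hedged sketch of that shape — the flag name, stub values, and metric names are assumptions, not the PR's actual API:

```python
def generate(prompt: str, return_metrics: bool = False):
    # Stub values standing in for the real decode loop.
    text = "<completion>"
    time_list = [0.012, 0.011]   # per-token latencies (illustrative)
    accepted_counts = [3, 2]     # tokens accepted per MTP step (illustrative)
    if return_metrics:
        return text, time_list, accepted_counts
    return text  # backward-compatible str return
```

Existing callers keep working unchanged, while the perf harness can pass `return_metrics=True` to collect timings.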
```python
@staticmethod
def num_params() -> int:
    return 1
```
EmbeddingParams registers two tensors (embedding + freqs_cis) but num_params() returns 1. This makes num_params() inconsistent with the actual parameter list and can break callers that rely on the count (e.g., slicing or continuous-storage packing). Update num_params() to 2 (or stop registering freqs_cis here if it is not intended to be part of the params list).
```diff
-    return 1
+    return 2
```
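A defensive variant, sketched here with assumed names, derives the count from a single list of registered tensor names so the two can never drift apart again:

```python
class EmbeddingParams:
    # Single source of truth for the registered tensors.
    PARAM_NAMES = ("embedding", "freqs_cis")

    @staticmethod
    def num_params() -> int:
        # Derived from the name list, so adding/removing a tensor
        # automatically updates the count.
        return len(EmbeddingParams.PARAM_NAMES)
```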
```python
        )
        res_dicts[weight_key] = weights
        res_dicts[scale_key] = scales
        state_dicts[weight_key] = None
    elif "unproj_scales" in key:  # skip unprojo_allreduce op:: scales
```
This function is typed as dict[str, torch.Tensor], but it assigns None into state_dicts (e.g., state_dicts[weight_key] = None). This will cause mypy failures and breaks the declared type contract. If the goal is to drop references for memory, consider collecting keys to delete after the loop, using del state_dicts[key], or widen the type to dict[str, torch.Tensor | None] and ensure None values are never re-consumed.
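The collect-then-delete alternative looks roughly like this, with a toy dict standing in for the real state_dicts and a hypothetical key predicate:

```python
# Toy stand-in: drop consumed weight entries without ever storing None,
# so the mapping stays dict[str, Tensor]-shaped throughout.
state_dicts = {"layer0.weight": object(), "layer0.unproj_scales": object()}

# Collect keys while iterating, delete only after the loop finishes
# (mutating a dict during iteration raises RuntimeError).
consumed = [key for key in state_dicts if key.endswith(".weight")]
for key in consumed:
    del state_dicts[key]
```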
```python
            f"Invalid compute kernel type: {compute_kernel_type}, \
            must be one of bf16, fp8, fp8mma."
        )
        self.compute_kernel_type = compute_kernel_type
```
The ValueError message is built using a backslash-escaped newline inside the f-string, which results in awkward whitespace in the final message. Prefer a single-line message or explicit string concatenation within parentheses so the error text is clean and predictable.
```diff
-    f"Invalid compute kernel type: {compute_kernel_type}, \
-    must be one of bf16, fp8, fp8mma."
-)
-self.compute_kernel_type = compute_kernel_type
+    f"Invalid compute kernel type: {compute_kernel_type}, must be one of bf16, fp8, fp8mma."
+)
+self.compute_kernel_type = compute_kernel_type
```
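The difference is visible in the raw strings: a backslash continuation keeps the next line's indentation inside the message, while implicit concatenation of adjacent literals inside parentheses does not. A minimal demonstration:

```python
kind = "int4"

# Backslash continuation: the next line's leading spaces leak into the text.
bad = (
    f"Invalid compute kernel type: {kind}, \
    must be one of bf16, fp8, fp8mma."
)

# Implicit concatenation: adjacent string literals join cleanly.
good = (
    f"Invalid compute kernel type: {kind}, "
    "must be one of bf16, fp8, fp8mma."
)

print(repr(bad))   # note the run of spaces before "must"
print(repr(good))  # clean single-spaced message
```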
```diff
 try:
-    return ctypes.CDLL(str(lib_path))
+    torch.ops.load_library(str(lib_path))
+    return lib_path
 except Exception as e:
```
_load_library now uses torch.ops.load_library(...) and returns lib_path, but the docstring says it returns “the loaded library”. Since the return value is unused (module-level call doesn’t capture it), consider returning None and updating the docstring/annotation accordingly, or return a meaningful handle consistently.
```python
print("Prompt:", prompt)
print("Completion:")
completion = generator.generate(prompt)
```
ShowHandsGenerator.generate() now returns a tuple (text, time_list, accepted_counts) (see python/models/deepseek_v3_2/dsa_show_hands.py), but this example assigns it to completion as if it were a string. Update the snippet to unpack the first element (or keep generate() returning str for backward compatibility).
```diff
-completion = generator.generate(prompt)
+completion, _, _ = generator.generate(prompt)
```
Release TileRT v0.1.2-alpha.1 with initial support for Multi-Token Prediction (MTP).
With mtp=3, decoding reaches up to 590 tokens/s on synthetic workloads and ~440 tokens/s on real generation tasks.